{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "4be2e6fa-2187-4617-8433-0db4fb0c099c",
   "metadata": {},
   "source": [
    "# LangChain 核心模块学习：Model I/O\n",
    "\n",
    "`Model I/O` 是 LangChain 为开发者提供的一套面向 LLM 的标准化模型接口，包括模型输入（Prompts）、模型输出（Output Parsers）和模型本身（Models）。\n",
    "\n",
    "- Prompts：模板化、动态选择和管理模型输入\n",
    "- Models：以通用接口调用语言模型\n",
    "- Output Parser：从模型输出中提取信息，并规范化内容\n",
    "\n",
    "![](../images/model_io.jpeg)\r\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2e64b01e-f5ad-4614-b0c3-a140f6bb575a",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: langchain in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (0.0.253)\n",
      "Requirement already satisfied: dataclasses-json<0.6.0,>=0.5.7 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (0.5.7)\n",
      "Requirement already satisfied: numpy<2,>=1 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (1.23.5)\n",
      "Requirement already satisfied: requests<3,>=2 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (2.31.0)\n",
      "Requirement already satisfied: langsmith<0.1.0,>=0.0.11 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (0.0.19)\n",
      "Requirement already satisfied: pydantic<2,>=1 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (1.10.8)\n",
      "Requirement already satisfied: aiohttp<4.0.0,>=3.8.3 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (3.8.4)\n",
      "Requirement already satisfied: async-timeout<5.0.0,>=4.0.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (4.0.2)\n",
      "Requirement already satisfied: numexpr<3.0.0,>=2.8.4 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (2.8.4)\n",
      "Requirement already satisfied: PyYAML>=5.3 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (6.0)\n",
      "Requirement already satisfied: openapi-schema-pydantic<2.0,>=1.2 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (1.2.4)\n",
      "Requirement already satisfied: tenacity<9.0.0,>=8.1.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (8.2.2)\n",
      "Requirement already satisfied: SQLAlchemy<3,>=1.4 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (2.0.15)\n",
      "Requirement already satisfied: attrs>=17.3.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (23.1.0)\n",
      "Requirement already satisfied: multidict<7.0,>=4.5 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (6.0.4)\n",
      "Requirement already satisfied: yarl<2.0,>=1.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.9.2)\n",
      "Requirement already satisfied: aiosignal>=1.1.2 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.1)\n",
      "Requirement already satisfied: frozenlist>=1.1.1 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.3)\n",
      "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (3.1.0)\n",
      "Requirement already satisfied: typing-inspect>=0.4.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from dataclasses-json<0.6.0,>=0.5.7->langchain) (0.9.0)\n",
      "Requirement already satisfied: marshmallow<4.0.0,>=3.3.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from dataclasses-json<0.6.0,>=0.5.7->langchain) (3.19.0)\n",
      "Requirement already satisfied: marshmallow-enum<2.0.0,>=1.5.1 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from dataclasses-json<0.6.0,>=0.5.7->langchain) (1.5.1)\n",
      "Requirement already satisfied: typing-extensions>=4.2.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from pydantic<2,>=1->langchain) (4.6.2)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from requests<3,>=2->langchain) (1.26.16)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from requests<3,>=2->langchain) (2023.5.7)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from requests<3,>=2->langchain) (3.4)\n",
      "Requirement already satisfied: greenlet!=0.4.17 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from SQLAlchemy<3,>=1.4->langchain) (2.0.2)\n",
      "Requirement already satisfied: packaging>=17.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from marshmallow<4.0.0,>=3.3.0->dataclasses-json<0.6.0,>=0.5.7->langchain) (23.1)\n",
      "Requirement already satisfied: mypy-extensions>=0.3.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from typing-inspect>=0.4.0->dataclasses-json<0.6.0,>=0.5.7->langchain) (1.0.0)\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "# 安装最新版本的 LangChain Python SDK（https://github.com/langchain-ai/langchain）\n",
    "!pip install -U langchain"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce4a2474-0b69-4830-85cd-3715c22df304",
   "metadata": {},
   "source": [
    "## 模型抽象 Model\n",
    "\n",
    "- 语言模型(LLMs): LangChain 的核心组件。LangChain并不提供自己的LLMs，而是为与许多不同的LLMs（OpenAI、Cohere、Hugging Face等）进行交互提供了一个标准接口。\n",
    "- 聊天模型(Chat Models): 语言模型的一种变体。虽然聊天模型在内部使用了语言模型，但它们提供的接口略有不同。与其暴露一个“输入文本，输出文本”的API不同，它们提供了一个以“聊天消息”作为输入和输出的接口。\n",
    "\n",
    "（注：对比 OpenAI Completion API和 Chat Completion API）"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f14f4cf-8e30-47ab-b8b1-d58a90b5b1c1",
   "metadata": {},
   "source": [
    "## 语言模型（LLMs)\n",
    "\n",
    "类继承关系：\n",
    "\n",
    "```\n",
    "BaseLanguageModel --> BaseLLM --> LLM --> <name>  # Examples: AI21, HuggingFaceHub, OpenAI\n",
    "```\n",
    "\n",
    "主要抽象:\n",
    "\n",
    "```\n",
    "LLMResult, PromptValue,\n",
    "CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,\n",
    "CallbackManager, AsyncCallbackManager,\n",
    "AIMessage, BaseMessage\n",
    "```\n",
    "\n",
    "**API 参考文档：https://api.python.langchain.com/en/latest/api_reference.html#module-langchain.llms**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e060d9e-ded9-4fd8-960f-1addd9c879d1",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "source": [
    "### BaseLanguageModel Class\n",
    "\n",
    "**代码实现：https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/schema/language_model.py**\n",
    "\n",
    "这个基类为语言模型定义了一个接口，该接口允许用户以不同的方式与模型交互（例如通过提示或消息）。`generate_prompt` 是其中的一个主要方法，它接受一系列提示，并返回模型的生成结果。\n",
    "\n",
    "\n",
    "```python\n",
    "# 定义 BaseLanguageModel 抽象基类，它从 Serializable, Runnable 和 ABC 继承\n",
    "class BaseLanguageModel(\n",
    "    Serializable, Runnable[LanguageModelInput, LanguageModelOutput], ABC\n",
    "):\n",
    "    \"\"\"\n",
    "    与语言模型交互的抽象基类。\n",
    "\n",
    "    所有语言模型的封装器都应从 BaseLanguageModel 继承。\n",
    "\n",
    "    主要提供三种方法：\n",
    "    - generate_prompt: 为一系列的提示值生成语言模型输出。提示值是可以转换为任何语言模型输入格式的模型输入（如字符串或消息）。\n",
    "    - predict: 将单个字符串传递给语言模型并返回字符串预测。\n",
    "    - predict_messages: 将一系列 BaseMessages（对应于单个模型调用）传递给语言模型，并返回 BaseMessage 预测。\n",
    "\n",
    "    每种方法都有对应的异步方法。\n",
    "    \"\"\"\n",
    "\n",
    "    # 定义一个抽象方法 generate_prompt，需要子类进行实现\n",
    "    @abstractmethod\n",
    "    def generate_prompt(\n",
    "        self,\n",
    "        prompts: List[PromptValue],  # 输入提示的列表\n",
    "        stop: Optional[List[str]] = None,  # 生成时的停止词列表\n",
    "        callbacks: Callbacks = None,  # 回调，用于执行例如日志记录或流式处理的额外功能\n",
    "        **kwargs: Any,  # 任意的额外关键字参数，通常会传递给模型提供者的 API 调用\n",
    "    ) -> LLMResult:\n",
    "        \"\"\"\n",
    "        将一系列的提示传递给模型并返回模型的生成。\n",
    "\n",
    "        对于提供批处理 API 的模型，此方法应使用批处理调用。\n",
    "\n",
    "        使用此方法时：\n",
    "            1. 希望利用批处理调用，\n",
    "            2. 需要从模型中获取的输出不仅仅是最顶部生成的值，\n",
    "            3. 构建与底层语言模型类型无关的链（例如，纯文本完成模型与聊天模型）。\n",
    "\n",
    "        参数:\n",
    "            prompts: 提示值的列表。提示值是一个可以转换为与任何语言模型匹配的格式的对象（对于纯文本生成模型为字符串，对于聊天模型为 BaseMessages）。\n",
    "            stop: 生成时使用的停止词。模型输出在这些子字符串的首次出现处截断。\n",
    "            callbacks: 要传递的回调。用于执行额外功能，例如在生成过程中进行日志记录或流式处理。\n",
    "            **kwargs: 任意的额外关键字参数。通常这些会传递给模型提供者的 API 调用。\n",
    "\n",
    "        返回值:\n",
    "            LLMResult，它包含每个输入提示的候选生成列表以及特定于模型提供者的额外输出。\n",
    "        \"\"\"\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d261e629-e2b7-4022-b205-4546f23810cb",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "source": [
    "### BaseLLM Class\n",
    "\n",
    "**代码实现：https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/llms/base.py**\n",
    "\n",
    "这段代码定义了一个名为 BaseLLM 的抽象基类。这个基类的主要目的是提供一个基本的接口来处理大型语言模型 (LLM)。\n",
    "\n",
    "```python\n",
    "# 定义 BaseLLM 抽象基类，它从 BaseLanguageModel[str] 和 ABC（Abstract Base Class）继承\n",
    "class BaseLLM(BaseLanguageModel[str], ABC):\n",
    "    \"\"\"Base LLM abstract interface.\n",
    "    \n",
    "    It should take in a prompt and return a string.\"\"\"\n",
    "\n",
    "    # 定义可选的缓存属性，其初始值为 None\n",
    "    cache: Optional[bool] = None\n",
    "\n",
    "    # 定义 verbose 属性，该属性决定是否打印响应文本\n",
    "    # 默认值使用 _get_verbosity 函数的结果\n",
    "    verbose: bool = Field(default_factory=_get_verbosity)\n",
    "    \"\"\"Whether to print out response text.\"\"\"\n",
    "\n",
    "    # 定义 callbacks 属性，其初始值为 None，并从序列化中排除\n",
    "    callbacks: Callbacks = Field(default=None, exclude=True)\n",
    "\n",
    "    # 定义 callback_manager 属性，其初始值为 None，并从序列化中排除\n",
    "    callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)\n",
    "\n",
    "    # 定义 tags 属性，这些标签会被添加到运行追踪中，其初始值为 None，并从序列化中排除\n",
    "    tags: Optional[List[str]] = Field(default=None, exclude=True)\n",
    "    \"\"\"Tags to add to the run trace.\"\"\"\n",
    "\n",
    "    # 定义 metadata 属性，这些元数据会被添加到运行追踪中，其初始值为 None，并从序列化中排除\n",
    "    metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)\n",
    "    \"\"\"Metadata to add to the run trace.\"\"\"\n",
    "\n",
    "    # 内部类定义了这个 pydantic 对象的配置\n",
    "    class Config:\n",
    "        \"\"\"Configuration for this pydantic object.\"\"\"\n",
    "\n",
    "        # 允许使用任意类型\n",
    "        arbitrary_types_allowed = True\n",
    "\n",
    "```\n",
    "这个基类使用了 Pydantic 的功能，特别是 Field 方法，用于定义默认值和序列化行为。BaseLLM 的子类需要提供实现具体功能的方法。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98dcdc04-13a8-4b0c-b67f-39198e48b957",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "source": [
    "### LLM Class\n",
    "\n",
    "**代码实现：https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/llms/base.py**\n",
    "\n",
    "这段代码定义了一个名为 LLM 的类，该类继承自 BaseLLM。这个类的目的是为了为用户提供一个简化的接口来处理LLM（大型语言模型），而不期望用户实现完整的 _generate 方法。\n",
    "\n",
    "```python\n",
    "\n",
    "# 继承自 BaseLLM 的 LLM 类\n",
    "class LLM(BaseLLM):\n",
    "    \"\"\"Base LLM abstract class.\n",
    "\n",
    "    The purpose of this class is to expose a simpler interface for working\n",
    "    with LLMs, rather than expect the user to implement the full _generate method.\n",
    "    \"\"\"\n",
    "\n",
    "    # 使用 @abstractmethod 装饰器定义一个抽象方法，子类需要实现这个方法\n",
    "    @abstractmethod\n",
    "    def _call(\n",
    "        self,\n",
    "        prompt: str,  # 输入提示\n",
    "        stop: Optional[List[str]] = None,  # 停止词列表\n",
    "        run_manager: Optional[CallbackManagerForLLMRun] = None,  # 运行管理器\n",
    "        **kwargs: Any,  # 其他关键字参数\n",
    "    ) -> str:\n",
    "        \"\"\"Run the LLM on the given prompt and input.\"\"\"\n",
    "        # 此方法的实现应在子类中提供\n",
    "\n",
    "    # _generate 方法使用了 _call 方法，用于处理多个提示\n",
    "    def _generate(\n",
    "        self,\n",
    "        prompts: List[str],  # 多个输入提示的列表\n",
    "        stop: Optional[List[str]] = None,\n",
    "        run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
    "        **kwargs: Any,\n",
    "    ) -> LLMResult:\n",
    "        \"\"\"Run the LLM on the given prompt and input.\"\"\"\n",
    "        # TODO: 在此处添加缓存逻辑\n",
    "        generations = []  # 用于存储生成的文本\n",
    "        # 检查 _call 方法的签名是否支持 run_manager 参数\n",
    "        new_arg_supported = inspect.signature(self._call).parameters.get(\"run_manager\")\n",
    "        for prompt in prompts:  # 遍历每个提示\n",
    "            # 根据是否支持 run_manager 参数来选择调用方法\n",
    "            text = (\n",
    "                self._call(prompt, stop=stop, run_manager=run_manager, **kwargs)\n",
    "                if new_arg_supported\n",
    "                else self._call(prompt, stop=stop, **kwargs)\n",
    "            )\n",
    "            # 将生成的文本添加到 generations 列表中\n",
    "            generations.append([Generation(text=text)])\n",
    "        # 返回 LLMResult 对象，其中包含 generations 列表\n",
    "        return LLMResult(generations=generations)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9e3c1da-ce94-4d68-bd4a-266c4eebd113",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "source": [
    "### LLMs 已支持模型清单\n",
    "\n",
    "**开发者文档：https://python.langchain.com/docs/integrations/llms/**\n",
    "\n",
    "**代码实现：https://github.com/langchain-ai/langchain/tree/master/libs/langchain/langchain/llms**\n",
    "\n",
    "### 使用 LangChain 调用 OpenAI GPT Completion API\n",
    "\n",
    "**代码实现：https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/llms/openai.py**\n",
    "\n",
    "#### BaseOpenAI Class\n",
    "\n",
    "```python\n",
    "class BaseOpenAI(BaseLLM):\n",
    "    \"\"\"OpenAI 大语言模型的基类。\"\"\"\n",
    "\n",
    "    @property\n",
    "    def lc_secrets(self) -> Dict[str, str]:\n",
    "        return {\"openai_api_key\": \"OPENAI_API_KEY\"}\n",
    "\n",
    "    @property\n",
    "    def lc_serializable(self) -> bool:\n",
    "        return True\n",
    "\n",
    "    client: Any  #: :meta private:\n",
    "    model_name: str = Field(\"text-davinci-003\", alias=\"model\")\n",
    "    \"\"\"使用的模型名。\"\"\"\n",
    "    temperature: float = 0.7\n",
    "    \"\"\"要使用的采样温度。\"\"\"\n",
    "    max_tokens: int = 256\n",
    "    \"\"\"完成中生成的最大令牌数。 \n",
    "    -1表示根据提示和模型的最大上下文大小返回尽可能多的令牌。\"\"\"\n",
    "    top_p: float = 1\n",
    "    \"\"\"在每一步考虑的令牌的总概率质量。\"\"\"\n",
    "    frequency_penalty: float = 0\n",
    "    \"\"\"根据频率惩罚重复的令牌。\"\"\"\n",
    "    presence_penalty: float = 0\n",
    "    \"\"\"惩罚重复的令牌。\"\"\"\n",
    "    n: int = 1\n",
    "    \"\"\"为每个提示生成多少完成。\"\"\"\n",
    "    best_of: int = 1\n",
    "    \"\"\"在服务器端生成best_of完成并返回“最佳”。\"\"\"\n",
    "    model_kwargs: Dict[str, Any] = Field(default_factory=dict)\n",
    "    \"\"\"保存任何未明确指定的`create`调用的有效模型参数。\"\"\"\n",
    "    openai_api_key: Optional[str] = None\n",
    "    openai_api_base: Optional[str] = None\n",
    "    openai_organization: Optional[str] = None\n",
    "    # 支持OpenAI的显式代理\n",
    "    openai_proxy: Optional[str] = None\n",
    "    batch_size: int = 20\n",
    "    \"\"\"传递多个文档以生成时使用的批处理大小。\"\"\"\n",
    "    request_timeout: Optional[Union[float, Tuple[float, float]]] = None\n",
    "    \"\"\"向OpenAI完成API的请求超时。 默认为600秒。\"\"\"\n",
    "    logit_bias: Optional[Dict[str, float]] = Field(default_factory=dict)\n",
    "    \"\"\"调整生成特定令牌的概率。\"\"\"\n",
    "    max_retries: int = 6\n",
    "    \"\"\"生成时尝试的最大次数。\"\"\"\n",
    "    streaming: bool = False\n",
    "    \"\"\"是否流式传输结果。\"\"\"\n",
    "    allowed_special: Union[Literal[\"all\"], AbstractSet[str]] = set()\n",
    "    \"\"\"允许的特殊令牌集。\"\"\"\n",
    "    disallowed_special: Union[Literal[\"all\"], Collection[str]] = \"all\"\n",
    "    \"\"\"不允许的特殊令牌集。\"\"\"\n",
    "    tiktoken_model_name: Optional[str] = None\n",
    "    \"\"\"使用此类时传递给tiktoken的模型名。\n",
    "    Tiktoken用于计算文档中的令牌数量以限制它们在某个限制以下。\n",
    "    默认情况下，设置为None时，这将与嵌入模型名称相同。\n",
    "    但是，在某些情况下，您可能希望使用此嵌入类与tiktoken不支持的模型名称。\n",
    "    这可以包括使用Azure嵌入或使用多个模型提供商的情况，这些提供商公开了类似OpenAI的API但模型不同。\n",
    "    在这些情况下，为了避免在调用tiktoken时出错，您可以在此处指定要使用的模型名称。\"\"\"\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0089c8a5-a859-49f2-bec0-fcd84f2f3b56",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.llms import OpenAI\n",
    "\n",
    "llm = OpenAI(model_name=\"text-davinci-003\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e1c9555-5f84-4ec5-9ef7-210def277c54",
   "metadata": {},
   "source": [
    "对比调用 OpenAI API：\n",
    "\n",
    "```python\n",
    "import openai\n",
    "\n",
    "data = openai.Completion.create(\n",
    "  model=\"text-davinci-003\",\n",
    "  prompt=\"Tell me a Joke\"\n",
    ")\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d17df13f-bc45-45b8-8f9e-5d3c7af71fea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Q: What did the fish say when he hit the wall?\n",
      "A: Dam!\n"
     ]
    }
   ],
   "source": [
    "print(llm(\"Tell me a Joke\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e5e98082-0a34-4155-b4d2-120ea2243a02",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "1、一个程序员跟他的经理说：“我需要两个星期的时间来完成这个程序。”经理说：“那很不错，但你能不能在一周内完成呢？”程序员：“我可以尝试，但我不能保证它能正常运行。”\n",
      "\n",
      "2、一个程序员坐在一间办公室里，看着他的电脑，一个同事来到他身边，问：“你在干\n"
     ]
    }
   ],
   "source": [
    "print(llm(\"讲10个给程序员听得笑话\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b2bc0c42-636f-4326-983d-e7a5e3e0e2e6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "256"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "llm.max_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a0745e85-e6a9-44ad-b4e0-412bf289910c",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm.max_tokens = 1024"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "338c149f-041f-478e-9337-f3f21871cc24",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1024"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "llm.max_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6b9e62cd-b461-41ed-b9c7-a4915f553c00",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "1.Q：为什么程序员都喜欢吃苹果？\n",
      "A：因为它们可以让你的代码变得更甜！\n",
      "\n",
      "2.Q：为什么程序员总是起得很晚？\n",
      "A：因为它们总是忙着debug他们的睡眠！\n",
      "\n",
      "3.Q：为什么程序员的眼睛总是红红的？\n",
      "A：因为他们把夜猫子当作黑客！\n",
      "\n",
      "4.Q：为什么程序员总是喝咖啡？\n",
      "A：因为它们可以让他们更加勤奋！\n",
      "\n",
      "5.Q：为什么程序员喜欢穿牛仔裤？\n",
      "A：因为它们可以让他们更加舒适！\n",
      "\n",
      "6.Q：为什么程序员总是穿着耳机？\n",
      "A：因为它们可以让他们黑客更加潜入！\n",
      "\n",
      "7.Q：为什么程序员总是在玩游戏？\n",
      "A：因为它们可以让他们更好地测试他们的代码！\n",
      "\n",
      "8.Q：为什么程序员总是穿着同一件衣服？\n",
      "A：因为它们可以让他们更加专注！\n",
      "\n",
      "9.Q：为什么程序员总是头发乱？\n",
      "A：因为它们的头脑可以更加聪明！\n",
      "\n",
      "10.Q：为什么程序员只会吃披萨？\n",
      "A：因为它们可以让他们思维更加灵活！\n"
     ]
    }
   ],
   "source": [
    "result = llm(\"讲10个给程序员听得笑话\")\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8e556c1b-8ae5-41da-a52f-1e448b1cb583",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm.temperature=0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "dab5a79d-9a97-46e5-8b28-00a57bbd3bac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "def quick_sort(arr):\n",
      "    if len(arr) <= 1:\n",
      "        return arr\n",
      "    pivot = arr[len(arr) // 2]\n",
      "    left = [x for x in arr if x < pivot]\n",
      "    middle = [x for x in arr if x == pivot]\n",
      "    right = [x for x in arr if x > pivot]\n",
      "    return quick_sort(left) + middle + quick_sort(right)\n"
     ]
    }
   ],
   "source": [
    "result = llm(\"生成可执行的快速排序 Python 代码\")\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "0d7f89d5-d126-45e3-85f0-56184435f226",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 使用 `exec` 定义 `quick_sort` 函数\n",
    "exec(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "787fe533-10aa-47f5-badd-b5227bd55c34",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 1, 2, 3, 6, 8, 10]\n"
     ]
    }
   ],
   "source": [
    "# 现在你可以调用这个函数了\n",
    "print(quick_sort([3,6,8,10,1,2,1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "579bfede-8594-4cdd-9c32-fb2343aa1adf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d65a996-a6c6-47da-b272-01fe70dec10d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "304f6c5b-2f80-41f3-a9ca-0c32c3af9a12",
   "metadata": {},
   "source": [
    "## 聊天模型（Chat Models)\n",
    "\n",
    "类继承关系：\n",
    "\n",
    "```\n",
    "BaseLanguageModel --> BaseChatModel --> <name>  # Examples: ChatOpenAI, ChatGooglePalm\n",
    "```\n",
    "\n",
    "主要抽象：\n",
    "\n",
    "```\n",
    "AIMessage, BaseMessage, HumanMessage\n",
    "```\n",
    "\n",
    "**API 参考文档：https://api.python.langchain.com/en/latest/api_reference.html#module-langchain.chat_models**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3d8773c-7584-44f9-a8f2-02f653b61a5a",
   "metadata": {},
   "source": [
    "### BaseChatModel Class\n",
    "\n",
    "**代码实现：https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chat_models/base.py**\n",
    "\n",
    "\n",
    "```python\n",
    "class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):\n",
    "    cache: Optional[bool] = None\n",
    "    \"\"\"是否缓存响应。\"\"\"\n",
    "    verbose: bool = Field(default_factory=_get_verbosity)\n",
    "    \"\"\"是否打印响应文本。\"\"\"\n",
    "    callbacks: Callbacks = Field(default=None, exclude=True)\n",
    "    \"\"\"添加到运行追踪的回调函数。\"\"\"\n",
    "    callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)\n",
    "    \"\"\"添加到运行追踪的回调函数管理器。\"\"\"\n",
    "    tags: Optional[List[str]] = Field(default=None, exclude=True)\n",
    "    \"\"\"添加到运行追踪的标签。\"\"\"\n",
    "    metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)\n",
    "    \"\"\"添加到运行追踪的元数据。\"\"\"\n",
    "\n",
    "    # 需要子类实现的 _generate 抽象方法\n",
    "    @abstractmethod\n",
    "    def _generate(\n",
    "        self,\n",
    "        messages: List[BaseMessage],\n",
    "        stop: Optional[List[str]] = None,\n",
    "        run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
    "        **kwargs: Any,\n",
    "    ) -> ChatResult:\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e45a2d17-1ac9-44e2-904a-b0520304c264",
   "metadata": {},
   "source": [
    "### ChatOpenAI Class（调用 Chat Completion API）\n",
    "\n",
    "\n",
    "**代码实现：https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chat_models/openai.py**\n",
    "\n",
    "```python\n",
    "class ChatOpenAI(BaseChatModel):\n",
    "    \"\"\"OpenAI Chat大语言模型的包装器。\n",
    "\n",
    "    要使用，您应该已经安装了``openai`` python包，并且\n",
    "    环境变量``OPENAI_API_KEY``已使用您的API密钥进行设置。\n",
    "\n",
    "    即使未在此类上明确保存，也可以传入任何有效的参数\n",
    "    至openai.create调用。\n",
    "    \"\"\"\n",
    "\n",
    "    @property\n",
    "    def lc_secrets(self) -> Dict[str, str]:\n",
    "        return {\"openai_api_key\": \"OPENAI_API_KEY\"}\n",
    "\n",
    "    @property\n",
    "    def lc_serializable(self) -> bool:\n",
    "        return True\n",
    "\n",
    "    client: Any = None  #: :meta private:\n",
    "    model_name: str = Field(default=\"gpt-3.5-turbo\", alias=\"model\")\n",
    "    \"\"\"要使用的模型名。\"\"\"\n",
    "    temperature: float = 0.7\n",
    "    \"\"\"使用的采样温度。\"\"\"\n",
    "    model_kwargs: Dict[str, Any] = Field(default_factory=dict)\n",
    "    \"\"\"保存任何未明确指定的`create`调用的有效模型参数。\"\"\"\n",
    "    openai_api_key: Optional[str] = None\n",
    "    \"\"\"API请求的基础URL路径，\n",
    "    如果不使用代理或服务仿真器，请留空。\"\"\"\n",
    "    openai_api_base: Optional[str] = None\n",
    "    openai_organization: Optional[str] = None\n",
    "    # 支持OpenAI的显式代理\n",
    "    openai_proxy: Optional[str] = None\n",
    "    request_timeout: Optional[Union[float, Tuple[float, float]]] = None\n",
    "    \"\"\"请求OpenAI完成API的超时。默认为600秒。\"\"\"\n",
    "    max_retries: int = 6\n",
    "    \"\"\"生成时尝试的最大次数。\"\"\"\n",
    "    streaming: bool = False\n",
    "    \"\"\"是否流式传输结果。\"\"\"\n",
    "    n: int = 1\n",
    "    \"\"\"为每个提示生成的聊天完成数。\"\"\"\n",
    "    max_tokens: Optional[int] = None\n",
    "    \"\"\"生成的最大令牌数。\"\"\"\n",
    "    tiktoken_model_name: Optional[str] = None\n",
    "    \"\"\"使用此类时传递给tiktoken的模型名称。\n",
    "    Tiktoken用于计算文档中的令牌数以限制\n",
    "    它们在某个限制之下。默认情况下，当设置为None时，这将\n",
    "    与嵌入模型名称相同。但是，在某些情况下，\n",
    "    您可能希望使用此嵌入类，模型名称不\n",
    "    由tiktoken支持。这可能包括使用Azure嵌入或\n",
    "    使用其中之一的多个模型提供商公开类似OpenAI的\n",
    "    API但模型不同。在这些情况下，为了避免在调用tiktoken时出错，\n",
    "    您可以在这里指定要使用的模型名称。\"\"\"\n",
    "\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "46e21bec-5389-4488-a58a-34cca6208ea0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chat_models import ChatOpenAI\n",
    "chat_model = ChatOpenAI(model_name=\"gpt-3.5-turbo\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "816e61a6-9710-4077-823c-27f042d5cd83",
   "metadata": {},
   "source": [
    "对比调用 OpenAI API：\n",
    "\n",
    "```python\n",
    "import openai\n",
    "\n",
    "data = openai.ChatCompletion.create(\n",
    "  model=\"gpt-3.5-turbo\",\n",
    "  messages=[\n",
    "        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
    "        {\"role\": \"user\", \"content\": \"Who won the world series in 2020?\"},\n",
    "        {\"role\": \"assistant\", \"content\": \"The Los Angeles Dodgers won the World Series in 2020.\"},\n",
    "        {\"role\": \"user\", \"content\": \"Where was it played?\"}\n",
    "    ]\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f811a8e1-f15c-4556-a88a-1bb22a9ac5d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.schema import (\n",
    "    AIMessage,\n",
    "    HumanMessage,\n",
    "    SystemMessage\n",
    ")\n",
    "\n",
    "messages = [SystemMessage(content=\"You are a helpful assistant.\"),\n",
    " HumanMessage(content=\"Who won the world series in 2020?\"),\n",
    " AIMessage(content=\"The Los Angeles Dodgers won the World Series in 2020.\"), \n",
    " HumanMessage(content=\"Where was it played?\")]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "71a95b7c-c53c-4ccd-836a-591c6aa142ee",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[SystemMessage(content='You are a helpful assistant.', additional_kwargs={}), HumanMessage(content='Who won the world series in 2020?', additional_kwargs={}, example=False), AIMessage(content='The Los Angeles Dodgers won the World Series in 2020.', additional_kwargs={}, example=False), HumanMessage(content='Where was it played?', additional_kwargs={}, example=False)]\n"
     ]
    }
   ],
   "source": [
    "print(messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8bf26981-d8f5-45aa-ac35-e38108013cb4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='The 2020 World Series was played at Globe Life Field in Arlington, Texas.', additional_kwargs={}, example=False)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chat_model(messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "9c03ffcd-c808-4e96-854b-549d7d7ca6f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "chat_result = chat_model(messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "63f920a9-c676-4ee0-873a-356ada56c0b4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "langchain.schema.messages.AIMessage"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(chat_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d681566-cde1-4ae5-8cd7-f53cf59c3e36",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
